{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating question-answer pairs from text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "import os\n",
    "\n",
    "from keras.layers import Input, Embedding, GRU, Bidirectional, Dense, Lambda\n",
    "from keras.models import Model, load_model\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.optimizers import Adam\n",
    "import keras.backend as K\n",
    "from keras.utils import plot_model\n",
    "\n",
    "import numpy as np\n",
    "import random\n",
    "import pickle as pkl\n",
    "\n",
    "from utils.write import training_data, test_data, collapse_documents, expand_answers, _read_data, glove\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run params\n",
    "SECTION = 'write'\n",
    "RUN_ID = '0001'\n",
    "DATA_NAME = 'qa'\n",
    "RUN_FOLDER = 'run/{}/'.format(SECTION)\n",
    "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n",
    "\n",
    "# Create the run folder together with its output subfolders.\n",
    "# os.makedirs also creates the missing 'run/<SECTION>' parents, and\n",
    "# exist_ok=True keeps the cell re-runnable and fills in any subfolder\n",
    "# that is missing from an already existing run folder.\n",
    "for subfolder in ('viz', 'images', 'weights'):\n",
    "    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)\n",
    "\n",
    "mode = 'build'  # alternative: 'load'; not read anywhere else in this notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One iterator per data split, built by the project helpers.\n",
    "# training_data_gen is only used to inspect a sample below (the training loop\n",
    "# calls training_data() itself); test_data_gen feeds the per-batch validation.\n",
    "# Debug aid: training_data_gen = [next(training_data_gen)] keeps only the first item.\n",
    "training_data_gen, test_data_gen = training_data(), test_data()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first item from training_data_gen so the next cell can print its fields.\n",
    "t = next(training_data_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print entry `idx` of every field of the inspected item `t`,\n",
    "# with a blank-line separator between consecutive fields.\n",
    "idx = 0\n",
    "\n",
    "fields = (\n",
    "    'document_tokens',\n",
    "    'question_input_tokens',\n",
    "    'answer_masks',\n",
    "    'answer_labels',\n",
    "    'question_output_tokens',\n",
    ")\n",
    "\n",
    "for position, field in enumerate(fields):\n",
    "    if position:\n",
    "        print('\\n')\n",
    "    print(field + '\\n', t[field][idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GloVe\n",
    "\n",
    "# Both sizes are read off the first two axes of the imported `glove` array.\n",
    "VOCAB_SIZE, EMBEDDING_DIMENS = glove.shape[:2]\n",
    "\n",
    "# Width of each direction of the tagger GRU; the encoder and decoder GRUs use 2 * GRU_UNITS.\n",
    "GRU_UNITS = 100\n",
    "\n",
    "print('GLOVE')\n",
    "for label, value in (('VOCAB_SIZE: ', VOCAB_SIZE), ('EMBEDDING_DIMENS: ', EMBEDDING_DIMENS)):\n",
    "    print(label, value)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# None leaves these dimensions unspecified in the Input shapes of the model.\n",
    "MAX_DOC_SIZE = MAX_ANSWER_SIZE = MAX_Q_SIZE = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- shared embedding layer, initialised from the GloVe array ---\n",
    "document_tokens = Input(shape=(MAX_DOC_SIZE,), name='document_tokens')\n",
    "\n",
    "embedding = Embedding(\n",
    "    input_dim=VOCAB_SIZE,\n",
    "    output_dim=EMBEDDING_DIMENS,\n",
    "    weights=[glove],\n",
    "    mask_zero=True,\n",
    "    name='embedding',\n",
    ")\n",
    "document_emb = embedding(document_tokens)\n",
    "\n",
    "# --- answer tagger: two-way softmax at every document position ---\n",
    "answer_outputs = Bidirectional(\n",
    "    GRU(GRU_UNITS, return_sequences=True),\n",
    "    name='answer_outputs',\n",
    ")(document_emb)\n",
    "answer_tags = Dense(2, activation='softmax', name='answer_tags')(answer_outputs)\n",
    "\n",
    "# --- encoder: batch_dot of the mask input with the tagger GRU outputs, fed to a GRU ---\n",
    "encoder_input_mask = Input(shape=(MAX_ANSWER_SIZE, MAX_DOC_SIZE), name='encoder_input_mask')\n",
    "encoder_inputs = Lambda(\n",
    "    lambda x: K.batch_dot(x[0], x[1]),\n",
    "    name='encoder_inputs',\n",
    ")([encoder_input_mask, answer_outputs])\n",
    "encoder_cell = GRU(2 * GRU_UNITS, name='encoder_cell')(encoder_inputs)\n",
    "\n",
    "# --- decoder: reuses the embedding layer and starts from the encoder GRU output ---\n",
    "decoder_inputs = Input(shape=(MAX_Q_SIZE,), name='decoder_inputs')\n",
    "decoder_emb = embedding(decoder_inputs)\n",
    "# NOTE(review): this sets an attribute on the layer's output tensor, not on a layer,\n",
    "# so it most likely does not freeze anything. If frozen embeddings are intended,\n",
    "# confirm and set `trainable` on the `embedding` layer instead.\n",
    "decoder_emb.trainable = False\n",
    "decoder_cell = GRU(2 * GRU_UNITS, return_sequences=True, name='decoder_cell')\n",
    "decoder_states = decoder_cell(decoder_emb, initial_state=[encoder_cell])\n",
    "\n",
    "# Softmax over the vocabulary at every decoder step.\n",
    "decoder_projection = Dense(\n",
    "    VOCAB_SIZE,\n",
    "    activation='softmax',\n",
    "    use_bias=False,\n",
    "    name='decoder_projection',\n",
    ")\n",
    "decoder_outputs = decoder_projection(decoder_states)\n",
    "\n",
    "# --- full training model with both outputs; the diagram is written to model.png ---\n",
    "total_model = Model(\n",
    "    [document_tokens, decoder_inputs, encoder_input_mask],\n",
    "    [answer_tags, decoder_outputs],\n",
    ")\n",
    "plot_model(total_model, to_file='model.png', show_shapes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the symbolic shape of the decoder embedding tensor.\n",
    "decoder_emb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sub-models built from tensors of total_model, so they reuse its layers.\n",
    "# answer_model: document tokens -> answer tags.\n",
    "answer_model = Model(inputs=document_tokens, outputs=[answer_tags])\n",
    "\n",
    "# decoder_initial_state_model: document tokens + mask -> encoder GRU output,\n",
    "# i.e. the tensor the training decoder receives as its initial state.\n",
    "decoder_initial_state_model = Model(\n",
    "    inputs=[document_tokens, encoder_input_mask],\n",
    "    outputs=[encoder_cell],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the layer-by-layer summary of the full training model.\n",
    "total_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### INFERENCE MODEL ####\n",
    "\n",
    "# Single-step decoder assembled from the same layer objects as the training model\n",
    "# (embedding, decoder_cell, decoder_projection), so it shares their weights.\n",
    "\n",
    "# One token per call.\n",
    "decoder_inputs_dynamic = Input(shape=(1,), name='decoder_inputs_dynamic')\n",
    "decoder_emb_dynamic = embedding(decoder_inputs_dynamic)\n",
    "\n",
    "# Initial state handed to decoder_cell; its width matches the 2 * GRU_UNITS units of that GRU.\n",
    "decoder_init_state_dynamic = Input(shape=(2 * GRU_UNITS,), name='decoder_init_state_dynamic')\n",
    "decoder_states_dynamic = decoder_cell(\n",
    "    decoder_emb_dynamic,\n",
    "    initial_state=[decoder_init_state_dynamic],\n",
    ")\n",
    "decoder_outputs_dynamic = decoder_projection(decoder_states_dynamic)\n",
    "\n",
    "# Outputs: the vocabulary softmax and the GRU output for this step\n",
    "# (presumably fed back as the next initial state by the generation code - not shown here).\n",
    "question_model = Model(\n",
    "    [decoder_inputs_dynamic, decoder_init_state_dynamic],\n",
    "    [decoder_outputs_dynamic, decoder_states_dynamic],\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### COMPILE TRAINING MODEL ####\n",
    "\n",
    "# NOTE(review): newer Keras releases name this argument `learning_rate` - confirm\n",
    "# against the installed version before upgrading.\n",
    "opti = Adam(lr=0.001)\n",
    "\n",
    "# One sparse categorical cross-entropy per output (answer tags, decoder tokens),\n",
    "# weighted equally.\n",
    "total_model.compile(\n",
    "    optimizer=opti,\n",
    "    loss=['sparse_categorical_crossentropy', 'sparse_categorical_crossentropy'],\n",
    "    loss_weights=[1, 1],\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training schedule.\n",
    "EPOCHS = 2000\n",
    "start_epoch = 1\n",
    "\n",
    "# Loss records; the training loop appends one entry per batch to each list.\n",
    "training_loss_history, test_loss_history = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs of a batch, in the order total_model was built with.\n",
    "def _model_inputs(batch):\n",
    "    return [batch['document_tokens'], batch['question_input_tokens'], batch['answer_masks']]\n",
    "\n",
    "\n",
    "# Both targets of a batch, each with a trailing axis of size 1 added.\n",
    "def _model_targets(batch):\n",
    "    return [\n",
    "        np.expand_dims(batch['answer_labels'], axis=-1),\n",
    "        np.expand_dims(batch['question_output_tokens'], axis=-1),\n",
    "    ]\n",
    "\n",
    "\n",
    "# The range end is exclusive, so this runs exactly EPOCHS epochs.\n",
    "for epoch in range(start_epoch, start_epoch + EPOCHS):\n",
    "    print('Epoch {0}'.format(epoch))\n",
    "\n",
    "    for i, batch in enumerate(training_data()):\n",
    "        # Pair every training batch with the next validation batch, restarting\n",
    "        # the validation iterator once it is exhausted.\n",
    "        val_batch = next(test_data_gen, None)\n",
    "        if val_batch is None:\n",
    "            test_data_gen = test_data()\n",
    "            val_batch = next(test_data_gen, None)\n",
    "            if val_batch is None:\n",
    "                # A fresh iterator with no items cannot be validated against.\n",
    "                raise RuntimeError('test_data() did not yield any batches')\n",
    "\n",
    "        training_loss = total_model.train_on_batch(_model_inputs(batch), _model_targets(batch))\n",
    "        test_loss = total_model.test_on_batch(_model_inputs(val_batch), _model_targets(val_batch))\n",
    "\n",
    "        training_loss_history.append(training_loss)\n",
    "        test_loss_history.append(test_loss)\n",
    "\n",
    "        print('{}: Train Loss: {} | Test Loss: {}'.format(i, training_loss, test_loss))\n",
    "\n",
    "    # Checkpoint the weights at the end of every epoch.\n",
    "    total_model.save_weights(os.path.join(RUN_FOLDER, 'weights/weights_{}.h5'.format(epoch)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### SHOW LOSSES ####\n",
    "\n",
    "# Each history entry is the value returned by train_on_batch / test_on_batch;\n",
    "# column 0 is plotted (for a multi-output Keras model that is the combined loss).\n",
    "plt.plot(np.array(training_loss_history)[:, 0], label='train')\n",
    "plt.plot(np.array(test_loss_history)[:, 0], label='test')\n",
    "plt.xlabel('batch')\n",
    "plt.ylabel('loss')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# Persist both histories; the with-block closes the file even if pickling fails.\n",
    "with open(os.path.join(RUN_FOLDER, 'weights/histories.pkl'), 'wb') as history_file:\n",
    "    pkl.dump([training_loss_history, test_loss_history], history_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gdl_test",
   "language": "python",
   "name": "gdl_test"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
